package com.yugao.fintech.framework.assistant.core.tree;

import cn.hutool.core.lang.Filter;
import cn.hutool.core.util.ObjectUtil;
import com.yugao.fintech.framework.assistant.core.CollectionUtils;
import com.yugao.fintech.framework.assistant.core.exception.BizException;

import java.util.*;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;

/**
 * Tree-structure utilities.
 * <p>
 * Converts a flat list of objects into a tree and offers helpers to traverse, flatten, filter,
 * clone and copy such trees. Node types must implement {@link TreeNode}, which exposes (among
 * others) the node id, the parent id and the children list.
 */
public class TreeUtils {

    /**
     * No-op. Retained for backward compatibility; it previously held only commented-out scratch code.
     */
    public static void main(String[] args) {
        // intentionally empty
    }

    /**
     * Collects, from a flat list, every ancestor of the given child nodes.
     * <p>
     * Each ancestor is returned (and passed to {@code item}) at most once, in the order it is first
     * discovered. The child nodes themselves are not included.
     *
     * @param flatList flat collection of nodes; nodes with a {@code null} id are ignored and, if several
     *                 nodes share an id, the first one wins
     * @param childIds ids of the nodes whose ancestors should be collected; unknown ids are skipped
     * @param item     optional callback invoked once per newly found ancestor; may be {@code null}
     * @return the distinct ancestors, in discovery order
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> listFlatParent(List<E> flatList, List<ID> childIds, Consumer<E> item) {
        if (CollectionUtils.isEmpty(flatList) || CollectionUtils.isEmpty(childIds)) {
            return new ArrayList<>();
        }
        // The merge function keeps the first node on duplicate ids instead of throwing IllegalStateException.
        Map<ID, E> nodeMap = flatList.stream().filter(e -> Objects.nonNull(e.treeId()))
                .collect(Collectors.toMap(TreeNode::treeId, Function.identity(), (first, second) -> first));
        Map<ID, E> parentNodeMap = new LinkedHashMap<>();
        for (ID childId : childIds) {
            E childNode = nodeMap.get(childId);
            if (Objects.isNull(childNode)) {
                continue;
            }

            ID parentId = childNode.treeParentId();
            while (Objects.nonNull(parentId) && nodeMap.containsKey(parentId)) {
                E parent = nodeMap.get(parentId);
                if (parentNodeMap.containsKey(parent.treeId())) {
                    // This ancestor (and therefore its whole ancestor chain) was already collected.
                    // Stopping here also guarantees termination when the data contains a cycle.
                    break;
                }
                parentNodeMap.put(parent.treeId(), parent);
                if (Objects.nonNull(item)) {
                    item.accept(parent);
                }
                parentId = parent.treeParentId();
            }
        }
        return new ArrayList<>(parentNodeMap.values());
    }

    /**
     * Collects, from a flat list, every descendant of the given parent nodes.
     * <p>
     * Descendants are emitted in post-order (a node after all of its own descendants). Each node is
     * emitted at most once, even when {@code parentIds} contains overlapping subtrees or the data
     * contains a cycle.
     *
     * @param flatList  flat collection of nodes; nodes with a {@code null} parent id are ignored
     * @param parentIds ids of the nodes whose descendants should be collected
     * @param item      optional callback invoked once per collected node; may be {@code null}
     * @return the collected descendants
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> listFlatChild(List<E> flatList, List<ID> parentIds, Consumer<E> item) {
        List<E> childNodes = new ArrayList<>();
        if (CollectionUtils.isEmpty(flatList) || CollectionUtils.isEmpty(parentIds)) {
            return childNodes;
        }
        Map<ID, List<E>> parentNodeMap = groupByParentId(flatList);
        // Shared across all parentIds so overlapping subtrees are not emitted twice.
        Set<E> visited = newIdentitySet();
        for (ID parentId : parentIds) {
            collectDescendants(parentNodeMap, parentId, childNodes, item, visited);
        }
        return childNodes;
    }

    /**
     * Recursively appends every descendant of {@code parentId} to {@code childNodes} in post-order.
     *
     * @param parentNodeMap parent id -> direct children index
     * @param parentId      id whose descendants are collected
     * @param childNodes    output list the descendants are appended to
     * @param callback      optional callback invoked once per appended node; may be {@code null}
     */
    public static <ID, E extends TreeNode<ID, E>> void recursionForeachChild(Map<ID, List<E>> parentNodeMap, ID parentId,
                                                                             List<E> childNodes, Consumer<E> callback) {
        Set<E> visited = newIdentitySet();
        collectDescendants(parentNodeMap, parentId, childNodes, callback, visited);
    }

    /**
     * Post-order descendant walk; {@code visited} (identity based) prevents duplicates and infinite recursion on cycles.
     */
    private static <ID, E extends TreeNode<ID, E>> void collectDescendants(Map<ID, List<E>> parentNodeMap, ID parentId,
                                                                           List<E> childNodes, Consumer<E> callback,
                                                                           Set<E> visited) {
        List<E> childList = parentNodeMap.get(parentId);
        if (Objects.isNull(childList)) {
            return;
        }
        for (E node : childList) {
            if (!visited.add(node)) {
                continue;
            }
            collectDescendants(parentNodeMap, node.treeId(), childNodes, callback, visited);
            if (Objects.nonNull(callback)) {
                callback.accept(node);
            }
            childNodes.add(node);
        }
    }

    /**
     * Re-ids the branch of a tree that runs from {@code rootId} down to {@code startId}.
     * <p>
     * The result is rooted at the root node. When start differs from root, the root and every node
     * above the start node keep only the single child on the path, while the start node keeps its full
     * subtree. When {@code startId} equals {@code rootId}, the root keeps its whole subtree. Every node in
     * the result gets a new id from {@code idGen} and a parent id pointing at its parent's new id; the
     * root's own parent id is left unchanged.
     * <p>
     * NOTE: despite the name, nodes are not cloned. The nodes of {@code treeList} are modified in
     * place. Call {@link #cloneTree(List)} first if the original tree must be preserved.
     *
     * @param treeList tree to copy from
     * @param rootId   id of the node that becomes the root of the result
     * @param startId  id of the start node; must be the root itself or one of its descendants
     * @param idGen    generates the new id of a node; it receives the node before its id/parent id are changed
     * @return the root node, or {@code null} if {@code treeList} is empty
     * @throws RuntimeException if either node does not exist, or the start node is not in the root's subtree
     */
    public static <ID, E extends TreeNode<ID, E>> E copyTree(List<E> treeList, ID rootId, ID startId, Function<E, ID> idGen) {
        if (CollectionUtils.isEmpty(treeList)) {
            return null;
        }
        Map<ID, E> map = new HashMap<>();
        foreachTree(treeList, node -> map.put(node.treeId(), node));

        E rootNode = map.get(rootId);
        E startNode = map.get(startId);
        if (Objects.isNull(rootNode) || Objects.isNull(startNode)) {
            throw new RuntimeException(String.format("数据节点不存在, startId: %s, rootId: %s", startId, rootId));
        }

        // Path from the start node (index 0) up to, but excluding, the root. It is resolved before any
        // mutation so that invalid input leaves the tree untouched.
        List<E> path = new ArrayList<>();
        Set<E> visited = newIdentitySet();
        for (E cursor = startNode; cursor != rootNode; cursor = map.get(cursor.treeParentId())) {
            if (Objects.isNull(cursor) || !visited.add(cursor)) {
                // Reached the top of the tree (or a cycle) without meeting the root.
                throw new RuntimeException(String.format("起始节点不是根节点的子孙节点, startId: %s, rootId: %s", startId, rootId));
            }
            path.add(cursor);
        }

        // Regenerate ids top-down so every node's parent already carries its new id.
        rootNode.treeId(idGen.apply(rootNode));
        E parent = rootNode;
        for (int i = path.size() - 1; i >= 0; i--) {
            E item = path.get(i);
            item.treeId(idGen.apply(item));
            item.treeParentId(parent.treeId());
            // Ancestors of the start node keep only the branch that leads to it.
            parent.children(Collections.singletonList(item));
            parent = item;
        }

        // 'parent' is now the start node (or the root when startId == rootId): re-id its whole subtree.
        recursionUpdateChild(parent.children(), idGen, parent.treeId());
        return rootNode;
    }

    /**
     * Assigns a new id (from {@code idGen}) and the given parent id to every node of {@code tree}, recursively.
     */
    private static <ID, E extends TreeNode<ID, E>> void recursionUpdateChild(List<E> tree, Function<E, ID> idGen, ID parentId) {
        if (CollectionUtils.isEmpty(tree)) {
            return;
        }
        for (E e : tree) {
            // idGen sees the node before it is modified.
            ID id = idGen.apply(e);
            e.treeId(id);
            e.treeParentId(parentId);
            recursionUpdateChild(e.children(), idGen, id);
        }
    }


    /**
     * Visits every node of the tree in pre-order (a node before its children).
     * <p>
     * A node's children are read after the callback has run, so changes the callback makes to the
     * children list are honoured.
     *
     * @param treeList top-level nodes; {@code null} or empty is a no-op
     * @param treeItem callback invoked for each node; may be {@code null}
     */
    public static <ID, E extends TreeNode<ID, E>> void foreachTree(List<E> treeList, Consumer<E> treeItem) {
        if (CollectionUtils.isEmpty(treeList)) {
            return;
        }
        for (E e : treeList) {
            if (Objects.nonNull(treeItem)) {
                treeItem.accept(e);
            }
            foreachTree(e.children(), treeItem);
        }
    }


    /**
     * Flattens a tree into a pre-order list.
     * <p>
     * NOTE: every node's children are set to {@code null}, so the input tree is dismantled.
     *
     * @param treeList top-level nodes
     * @return all nodes in pre-order
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> flatTree(List<E> treeList) {
        List<E> resp = new ArrayList<>();
        foreachTree(treeList, resp::add);
        resp.forEach(e -> e.children(null));
        return resp;
    }

    /**
     * Flattens the tree below (and including) {@code treeRoot} into a pre-order list.
     * <p>
     * NOTE: every node's children are set to {@code null}, so the input tree is dismantled.
     *
     * @param treeRoot root node; {@code null} yields an empty list
     * @return all nodes in pre-order
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> flatTree(E treeRoot) {
        if (Objects.isNull(treeRoot)) {
            return new ArrayList<>();
        }
        List<E> resp = new ArrayList<>();
        foreachTree(Collections.singletonList(treeRoot), resp::add);
        resp.forEach(e -> e.children(null));
        return resp;
    }


    /**
     * Upper-cases the first character of {@code str}; {@code null} and "" are returned unchanged.
     *
     * @param str input string
     */
    private static String firstCharToUpperCase(String str) {
        if (str == null || str.isEmpty()) {
            return str;
        }
        // Character.toUpperCase is a no-op for characters that are already upper case or have no case,
        // unlike the previous "minus 32" arithmetic.
        return Character.toUpperCase(str.charAt(0)) + str.substring(1);
    }


    /**
     * Recursively clones a tree.
     * <p>
     * Each node is deep-copied with hutool {@code ObjectUtil.cloneByStream} (Java serialization), so
     * nodes must be {@code Serializable}. Children are always cloned recursively from the source node,
     * even when the children field is not part of the serialized form. If the clone of a node is
     * {@code null}, that {@code null} is placed in the result as-is.
     *
     * @param treeList top-level nodes; {@code null} yields an empty list
     * @return a new, mutable list of cloned top-level nodes
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> cloneTree(List<E> treeList) {
        if (CollectionUtils.isEmpty(treeList)) {
            return new ArrayList<>();
        }
        List<E> result = new ArrayList<>(treeList.size());
        for (E node : treeList) {
            final E newNode = ObjectUtil.cloneByStream(node);
            // Decide by the SOURCE node: if the clone came back without children, checking it would
            // silently drop the whole subtree.
            if (Objects.nonNull(newNode) && CollectionUtils.isNotEmpty(node.children())) {
                newNode.children(cloneTree(node.children()));
            }
            result.add(newNode);
        }
        return result;
    }

    /**
     * Filters a copy of the tree, leaving the input untouched. See {@link #filter(List, Filter)} for the rules.
     *
     * @param treeList tree structure
     * @param filter   predicate; {@code true} keeps the node, {@code false} rejects it
     * @param <ID>     id type
     * @param <E>      node type
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> filterNew(List<E> treeList, Filter<E> filter) {
        List<E> newTreeList = cloneTree(treeList);
        return filter(newTreeList, filter);
    }

    /**
     * Filters a tree in place.
     * <p>
     * The rules are:
     * <ul>
     *     <li>A node accepted by {@code filter} is kept together with its entire subtree, which is not filtered further.</li>
     *     <li>A rejected node is kept only if at least one of its descendants survives.</li>
     *     <li>A {@code null} filter rejects every node, so everything is removed.</li>
     * </ul>
     * {@code treeList} and the nested children lists are modified in place and must be mutable.
     *
     * @param treeList tree structure; {@code null} yields an empty list
     * @param filter   predicate; {@code true} keeps the node, {@code false} rejects it
     * @param <ID>     id type
     * @param <E>      node type
     * @return {@code treeList} itself, after filtering
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> filter(List<E> treeList, Filter<E> filter) {
        if (treeList == null) {
            return new ArrayList<>();
        }

        // Identity based: removeAll() relies on equals(), which could also drop kept nodes that merely
        // compare equal to a rejected one.
        Set<E> rejected = newIdentitySet();

        for (E node : treeList) {
            if (Objects.nonNull(filter) && filter.accept(node)) {
                continue;
            }
            List<E> children = node.children();
            if (CollectionUtils.isNotEmpty(children)) {
                if (filter(children, filter).isEmpty()) {
                    // No descendant survived.
                    node.children(Collections.emptyList());
                    rejected.add(node);
                }
            } else {
                rejected.add(node);
            }
        }
        // Removing after the loop avoids ConcurrentModificationException. The call is skipped when there is
        // nothing to remove, so immutable lists are left alone.
        if (!rejected.isEmpty()) {
            treeList.removeIf(rejected::contains);
        }
        return treeList;
    }

    /**
     * Builds a tree from a flat list.
     *
     * @param flatList flat collection of nodes
     * @param parentId parent id of the top-level nodes, usually something like 0
     * @return the sorted top-level nodes
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> build(List<E> flatList, ID parentId) {
        return build(flatList, parentId, null);
    }

    /**
     * Builds a tree from a flat list.
     * <p>
     * Nodes whose parent id equals {@code parentId} become the top-level nodes. Each is populated
     * recursively with its children (sorted by {@code sortNo}, {@code null} treated as 1) and decorated
     * with level, "/"-joined names, leaf flag and parent-id chain. A node whose id equals
     * {@code parentId} is decorated (level 0, names) and passed to {@code callback}, but is not returned.
     * Nodes with a {@code null} parent id are never attached as children.
     * <p>
     * NOTE(review): top-level nodes get {@code treeParentIds = parentId}, but their descendants' chains
     * start at the top-level node's id and omit {@code parentId}. Levels start at 0 both for the node whose
     * id is {@code parentId} and for the top-level nodes. This is preserved as-is; confirm it is intended.
     *
     * @param flatList flat collection of nodes
     * @param parentId parent id of the top-level nodes, usually something like 0; must not be null
     * @param callback optional hook for caller post-processing, invoked once per decorated node (children
     *                 before their parent) after that node's fields are final; may be {@code null}
     * @return the sorted top-level nodes
     * @throws BizException if {@code parentId} is null
     */
    public static <ID, E extends TreeNode<ID, E>> List<E> build(List<E> flatList, ID parentId, Consumer<E> callback) {
        if (CollectionUtils.isEmpty(flatList)) {
            return Collections.emptyList();
        }

        if (Objects.isNull(parentId)) {
            throw new BizException("parentId is null");
        }
        List<E> returnList = new ArrayList<>();

        // Nodes with a null parent id are left out of the index (groupingBy rejects null keys).
        Map<ID, List<E>> parentNodeMap = groupByParentId(flatList);

        for (E node : flatList) {

            // The node whose id equals parentId: decorate it, but it is not part of the result.
            if (parentId.equals(node.treeId())) {
                node.treeLevel(0);
                node.treeNames(node.name());
                if (Objects.nonNull(callback)) {
                    callback.accept(node);
                }
                continue;
            }

            // Top-level node: build its subtree. Its own parent-id chain is parentId, while its children's
            // chains are derived from the node's own id.
            if (Objects.equals(node.treeParentId(), parentId)) {
                recursionFn(parentNodeMap, node, 0, node.name(),
                        String.valueOf(parentId), String.valueOf(node.treeId()), callback);
                returnList.add(node);
            }

        }
        sortBySortNo(returnList);
        return returnList;
    }

    /**
     * Recursively attaches children to {@code currentNode} and decorates it and its descendants.
     *
     * @param level          tree level assigned to currentNode
     * @param treeNames      "/"-joined names assigned to currentNode
     * @param parentIds      parent-id chain assigned to currentNode
     * @param parentIdsBase  chain the children's parent-id chains are derived from (the current node's id
     *                       is appended unless it already equals this base)
     * @param callback       optional; invoked for each node once its fields are final
     */
    private static <ID, E extends TreeNode<ID, E>> void recursionFn(Map<ID, List<E>> parentNodeMap,
                                                                    E currentNode,
                                                                    int level,
                                                                    String treeNames,
                                                                    String parentIds,
                                                                    String parentIdsBase,
                                                                    Consumer<E> callback) {
        List<E> childList = getChildNode(parentNodeMap, currentNode);
        sortBySortNo(childList);
        currentNode.children(childList);
        currentNode.treeLeaf(CollectionUtils.isEmpty(currentNode.children()));
        currentNode.treeLevel(level);
        currentNode.treeNames(treeNames);
        currentNode.treeParentIds(parentIds);

        String currentId = String.valueOf(currentNode.treeId());
        String childParentIds = currentId.equals(parentIdsBase) ? parentIdsBase : parentIdsBase + "," + currentId;
        for (E node : childList) {
            String childTreeNames = treeNames + "/" + node.name();
            if (hasChild(parentNodeMap, node)) {
                recursionFn(parentNodeMap, node, level + 1, childTreeNames, childParentIds, childParentIds, callback);
            } else {
                node.treeLeaf(true);
                node.treeLevel(level + 1);
                node.treeNames(childTreeNames);
                node.treeParentIds(childParentIds);
                if (Objects.nonNull(callback)) {
                    callback.accept(node);
                }
            }
        }
        if (Objects.nonNull(callback)) {
            callback.accept(currentNode);
        }
    }

    /**
     * Sorts nodes in place by {@code sortNo}, treating {@code null} as 1.
     */
    private static <ID, E extends TreeNode<ID, E>> void sortBySortNo(List<E> nodes) {
        nodes.sort(Comparator.comparing(e -> ObjectUtil.defaultIfNull(e.sortNo(), 1)));
    }

    /**
     * Indexes nodes by parent id, skipping nodes whose parent id is {@code null}.
     */
    private static <ID, E extends TreeNode<ID, E>> Map<ID, List<E>> groupByParentId(List<E> flatList) {
        return flatList.stream().filter(e -> Objects.nonNull(e.treeParentId()))
                .collect(Collectors.groupingBy(TreeNode::treeParentId));
    }

    /**
     * Creates an empty set that compares elements by reference rather than {@code equals}.
     */
    private static <T> Set<T> newIdentitySet() {
        return Collections.newSetFromMap(new IdentityHashMap<T, Boolean>());
    }

    /**
     * Tells whether {@code currentNode} has at least one child in the index.
     */
    private static <ID, E extends TreeNode<ID, E>> boolean hasChild(Map<ID, List<E>> parentNodeMap, E currentNode) {
        return CollectionUtils.isNotEmpty(parentNodeMap.get(currentNode.treeId()));
    }

    /**
     * Returns a fresh, mutable copy of the direct children of {@code currentNode}.
     *
     * @param parentNodeMap parent id -> direct children index
     * @param currentNode   current node
     */
    private static <ID, E extends TreeNode<ID, E>> List<E> getChildNode(Map<ID, List<E>> parentNodeMap, E currentNode) {
        List<E> allChild = parentNodeMap.get(currentNode.treeId());
        return Objects.isNull(allChild) ? new ArrayList<>() : new ArrayList<>(allChild);
    }
}
